Distracted-Driver Behaviour Detection from Labelled Driving Images Using Computer Vision¶

1. Create Data Pipeline for Training¶

1.1 Import Dependencies¶

In [2]:
import pandas as pd
import os
import glob
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

1.2 Read and Pre-process Labels¶

In [3]:
df = pd.read_csv("driver_imgs_list.csv")
df
Out[3]:
classname img
0 c0 img_44733.jpg
1 c0 img_72999.jpg
2 c0 img_25094.jpg
3 c0 img_69092.jpg
4 c0 img_92629.jpg
... ... ...
22419 c9 img_56936.jpg
22420 c9 img_46218.jpg
22421 c9 img_25946.jpg
22422 c9 img_67850.jpg
22423 c9 img_9684.jpg

22424 rows × 2 columns

In [4]:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 22424 entries, 0 to 22423
Data columns (total 2 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   classname  22424 non-null  object
 1   img        22424 non-null  object
dtypes: object(2)
memory usage: 350.5+ KB
In [5]:
df.columns
Out[5]:
Index(['classname', 'img'], dtype='object')
In [6]:
# Class code -> human-readable label for the 10 driver-behaviour classes.
maps = {
    "c0": "safe driving",
    "c1": "texting - right",
    "c2": "talking on the phone - right",
    "c3": "texting - left",
    "c4": "talking on the phone - left",
    "c5": "operating the radio",
    "c6": "drinking",
    "c7": "reaching behind",
    "c8": "hair and makeup",
    "c9": "talking to passenger",
}
In [7]:
df["target"] = df["classname"].map(maps)
df
Out[7]:
classname img target
0 c0 img_44733.jpg safe driving
1 c0 img_72999.jpg safe driving
2 c0 img_25094.jpg safe driving
3 c0 img_69092.jpg safe driving
4 c0 img_92629.jpg safe driving
... ... ... ...
22419 c9 img_56936.jpg talking to passenger
22420 c9 img_46218.jpg talking to passenger
22421 c9 img_25946.jpg talking to passenger
22422 c9 img_67850.jpg talking to passenger
22423 c9 img_9684.jpg talking to passenger

22424 rows × 3 columns

In [14]:
df["source_path"] = "imgs/train/"+df["classname"]+"/"+df["img"]
df
Out[14]:
classname img target source_path
0 c0 img_44733.jpg safe driving imgs/train/c0/img_44733.jpg
1 c0 img_72999.jpg safe driving imgs/train/c0/img_72999.jpg
2 c0 img_25094.jpg safe driving imgs/train/c0/img_25094.jpg
3 c0 img_69092.jpg safe driving imgs/train/c0/img_69092.jpg
4 c0 img_92629.jpg safe driving imgs/train/c0/img_92629.jpg
... ... ... ... ...
22419 c9 img_56936.jpg talking to passenger imgs/train/c9/img_56936.jpg
22420 c9 img_46218.jpg talking to passenger imgs/train/c9/img_46218.jpg
22421 c9 img_25946.jpg talking to passenger imgs/train/c9/img_25946.jpg
22422 c9 img_67850.jpg talking to passenger imgs/train/c9/img_67850.jpg
22423 c9 img_9684.jpg talking to passenger imgs/train/c9/img_9684.jpg

22424 rows × 4 columns

In [12]:
targets = pd.get_dummies(df['target'], dtype=int)
targets
Out[12]:
drinking hair and makeup operating the radio reaching behind safe driving talking on the phone - left talking on the phone - right talking to passenger texting - left texting - right
0 0 0 0 0 1 0 0 0 0 0
1 0 0 0 0 1 0 0 0 0 0
2 0 0 0 0 1 0 0 0 0 0
3 0 0 0 0 1 0 0 0 0 0
4 0 0 0 0 1 0 0 0 0 0
... ... ... ... ... ... ... ... ... ... ...
22419 0 0 0 0 0 0 0 1 0 0
22420 0 0 0 0 0 0 0 1 0 0
22421 0 0 0 0 0 0 0 1 0 0
22422 0 0 0 0 0 0 0 1 0 0
22423 0 0 0 0 0 0 0 1 0 0

22424 rows × 10 columns

In [14]:
# Integer index -> class name, for decoding argmax model outputs later.
map_back = dict(enumerate(targets.columns))
map_back
Out[14]:
{0: 'drinking',
 1: 'hair and makeup',
 2: 'operating the radio',
 3: 'reaching behind',
 4: 'safe driving',
 5: 'talking on the phone - left',
 6: 'talking on the phone - right',
 7: 'talking to passenger',
 8: 'texting - left',
 9: 'texting - right'}
In [15]:
df["target"].value_counts().plot(kind='bar')
Out[15]:
<Axes: >
No description has been provided for this image
In [16]:
paths = df["source_path"]

1.3 Image Processing Functions¶

In [17]:
H = 256
W = 256
def process_one_frameid(path):
    """Load one image file, resize it to (H, W) and scale pixels to [0, 1].

    Parameters
    ----------
    path : str or bytes tensor — path to an image file.

    Returns
    -------
    float32 numpy array of shape (H, W, C) with values in [0, 1].
    """
    # The dataset files are JPEGs ('img_*.jpg'), so use the format-agnostic
    # decoder instead of decode_png, which only nominally targets PNG.
    frame = tf.io.decode_image(tf.io.read_file(path), expand_animations=False)
    frame = tf.image.resize(frame, (H, W)) / 255  # resize and scale for faster processing
    return frame.numpy()
In [15]:
df["source_path"][0]
Out[15]:
'imgs/train/c0/img_44733.jpg'
In [19]:
x = process_one_frameid(df["source_path"][0])
In [20]:
plt.imshow(x)
Out[20]:
<matplotlib.image.AxesImage at 0x7ce420c1a6b0>
No description has been provided for this image

1.4 Split into Train-Test¶

In [16]:
# Stratified 80/20 split; fixed random_state so the split (and therefore all
# downstream metrics) is reproducible on a fresh kernel re-run.
X_train, X_test, y_train, y_test = train_test_split(
    df["source_path"].values, targets.values,
    test_size=0.2, stratify=targets, random_state=42)
In [17]:
X_train.shape
Out[17]:
(17939,)
In [18]:
X_train
Out[18]:
array(['imgs/train/c0/img_92742.jpg', 'imgs/train/c1/img_90351.jpg',
       'imgs/train/c3/img_28087.jpg', ..., 'imgs/train/c4/img_36036.jpg',
       'imgs/train/c3/img_86411.jpg', 'imgs/train/c1/img_63767.jpg'],
      dtype=object)
In [24]:
x = process_one_frameid(X_train[0])
In [25]:
plt.imshow(x)
Out[25]:
<matplotlib.image.AxesImage at 0x7ce42272b850>
No description has been provided for this image
In [26]:
y_train[0]
Out[26]:
array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0], dtype=uint8)
In [27]:
def decode_output(single_output):
    """Map a one-hot (or probability) vector back to its class-name label."""
    best = int(np.argmax(single_output))
    return targets.columns[best]
In [28]:
decode_output(y_train[0])
Out[28]:
'talking on the phone - left'

1.5 Define Dataset Object for Large Volume Data¶

In [29]:
# Train dataset: decode/resize images lazily so the full set never has to
# fit in memory at once.
AUTOTUNE = tf.data.AUTOTUNE
x = tf.data.Dataset.from_tensor_slices(X_train)
x = x.map(lambda window: tf.numpy_function(process_one_frameid, [window], tf.float32),
          num_parallel_calls=AUTOTUNE)  # decode images in parallel
y = tf.data.Dataset.from_tensor_slices(y_train)
dataset = tf.data.Dataset.zip((x, y))
# Reshuffle each epoch — train_test_split only shuffled the order once.
dataset = dataset.shuffle(1024)
dataset = dataset.batch(32)
dataset = dataset.prefetch(AUTOTUNE)  # overlap preprocessing with training
In [30]:
batch = dataset.take(1)
In [31]:
x,y = batch.as_numpy_iterator().next()
In [32]:
x.shape
Out[32]:
(32, 256, 256, 3)
In [33]:
y.shape
Out[33]:
(32, 10)
In [34]:
X_test.shape
Out[34]:
(4485,)
In [35]:
# Validation / test dataset: no shuffling, so metrics are computed over a
# stable ordering; parallel map + prefetch overlap decoding with evaluation.
valx = tf.data.Dataset.from_tensor_slices(X_test)
valx = valx.map(lambda window: tf.numpy_function(process_one_frameid, [window], tf.float32),
                num_parallel_calls=tf.data.AUTOTUNE)
valy = tf.data.Dataset.from_tensor_slices(y_test)
valdataset = tf.data.Dataset.zip((valx, valy))
valdataset = valdataset.batch(32)
valdataset = valdataset.prefetch(tf.data.AUTOTUNE)
In [36]:
i,o = dataset.as_numpy_iterator().next()
In [37]:
i.shape # one batch of frames: (batch, height, width, channels)
Out[37]:
(32, 256, 256, 3)
In [38]:
o.shape # batch, outputs
Out[38]:
(32, 10)

2. Model using functional API of keras using 2D CNN for image Classification¶

2.1 Import Dependencies¶

In [39]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Input, TimeDistributed, MaxPooling2D, concatenate, Conv3D, MaxPooling3D, Add
from tensorflow.keras.models import Model, Sequential
import shutil

2.2 Connect to GPU¶

In [40]:
strategy = tf.distribute.MirroredStrategy()

2.3 Define the Model¶

In [41]:
with strategy.scope():
    # 256x256 RGB input frame (H and W come from the preprocessing cell).
    image = Input(shape=(H, W, 3), name='input')
    # Frozen ImageNet-pretrained backbone used as a fixed feature extractor.
    # NOTE(review): the pipeline feeds pixels scaled to [0, 1], while Keras'
    # MobileNetV2 preprocess_input nominally expects [-1, 1] — confirm intended.
    one_frame_model = tf.keras.applications.MobileNetV2(weights='imagenet', include_top=False, input_shape=(H,W,3))
    one_frame_model.trainable=False
    features = one_frame_model(image)
    # Pool the backbone's 8x8 spatial grid down to 1x1 before flattening.
    features = MaxPooling2D(8)(features)
    flat = Flatten()(features)
    dense = Dense(256, activation='relu')(flat)
    dense = Dense(128, activation='relu')(dense)
    dense = Dense(128, activation='relu')(dense)
    dense = Dense(64, activation='relu')(dense)

    output = Dense(targets.shape[1], activation='softmax')(dense) # 10 classes found in this dataset
    model = Model(image, output)
    # Small learning rate: only the dense head (~386K params) is trainable.
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss='categorical_crossentropy', metrics=['accuracy', tf.keras.metrics.Precision(name='precision'), tf.keras.metrics.Recall(name='recall')])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
9406464/9406464 [==============================] - 0s 0us/step

2.4 Display Model and Model Summary¶

In [42]:
tf.keras.utils.plot_model(model, show_shapes=True)
Out[42]:
No description has been provided for this image
In [43]:
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input (InputLayer)          [(None, 256, 256, 3)]     0         
                                                                 
 mobilenetv2_1.00_224 (Funct  (None, 8, 8, 1280)       2257984   
 ional)                                                          
                                                                 
 max_pooling2d (MaxPooling2D  (None, 1, 1, 1280)       0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 256)               327936    
                                                                 
 dense_1 (Dense)             (None, 128)               32896     
                                                                 
 dense_2 (Dense)             (None, 128)               16512     
                                                                 
 dense_3 (Dense)             (None, 64)                8256      
                                                                 
 dense_4 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 2,644,234
Trainable params: 386,250
Non-trainable params: 2,257,984
_________________________________________________________________

2.5 Create Logs Directory for TensorBoard¶

In [44]:
# Recreate a clean ./logs directory for TensorBoard.
# ignore_errors replaces the previous bare `except: pass`, which would also
# have silently swallowed unrelated failures (e.g. permission errors).
shutil.rmtree("./logs", ignore_errors=True)
os.makedirs("./logs", exist_ok=True)
In [45]:
# TensorBoard logging plus a checkpoint that keeps only the best full model
# (architecture + weights) as measured by validation accuracy.
tfboard = tf.keras.callbacks.TensorBoard("./logs")
model_check = tf.keras.callbacks.ModelCheckpoint("./model.h5", save_best_only=True, save_weights_only=False, monitor='val_accuracy', mode='max')

2.6 Train the Model¶

In [46]:
model.fit(dataset, epochs=100, validation_data=valdataset, callbacks=[tfboard, model_check])
Epoch 1/100
561/561 [==============================] - 264s 436ms/step - loss: 2.3153 - accuracy: 0.1796 - precision: 0.1897 - recall: 0.0021 - val_loss: 2.0743 - val_accuracy: 0.2707 - val_precision: 0.5000 - val_recall: 2.2297e-04
Epoch 2/100
561/561 [==============================] - 165s 295ms/step - loss: 1.8982 - accuracy: 0.3730 - precision: 0.8500 - recall: 0.0114 - val_loss: 1.7481 - val_accuracy: 0.4448 - val_precision: 0.9389 - val_recall: 0.0274
Epoch 3/100
561/561 [==============================] - 155s 276ms/step - loss: 1.5636 - accuracy: 0.5272 - precision: 0.9171 - recall: 0.0857 - val_loss: 1.4144 - val_accuracy: 0.5781 - val_precision: 0.9161 - val_recall: 0.1559
Epoch 4/100
561/561 [==============================] - 153s 273ms/step - loss: 1.2510 - accuracy: 0.6461 - precision: 0.9224 - recall: 0.2427 - val_loss: 1.1329 - val_accuracy: 0.6776 - val_precision: 0.9262 - val_recall: 0.3331
Epoch 5/100
561/561 [==============================] - 154s 275ms/step - loss: 0.9982 - accuracy: 0.7270 - precision: 0.9242 - recall: 0.4168 - val_loss: 0.9205 - val_accuracy: 0.7465 - val_precision: 0.9196 - val_recall: 0.4923
Epoch 6/100
561/561 [==============================] - 155s 275ms/step - loss: 0.8098 - accuracy: 0.7841 - precision: 0.9295 - recall: 0.5630 - val_loss: 0.7622 - val_accuracy: 0.7967 - val_precision: 0.9265 - val_recall: 0.6132
Epoch 7/100
561/561 [==============================] - 156s 279ms/step - loss: 0.6697 - accuracy: 0.8266 - precision: 0.9363 - recall: 0.6659 - val_loss: 0.6448 - val_accuracy: 0.8263 - val_precision: 0.9281 - val_recall: 0.6874
Epoch 8/100
561/561 [==============================] - 154s 274ms/step - loss: 0.5660 - accuracy: 0.8536 - precision: 0.9404 - recall: 0.7364 - val_loss: 0.5576 - val_accuracy: 0.8459 - val_precision: 0.9304 - val_recall: 0.7427
Epoch 9/100
561/561 [==============================] - 154s 275ms/step - loss: 0.4864 - accuracy: 0.8760 - precision: 0.9453 - recall: 0.7864 - val_loss: 0.4887 - val_accuracy: 0.8658 - val_precision: 0.9344 - val_recall: 0.7817
Epoch 10/100
561/561 [==============================] - 154s 275ms/step - loss: 0.4233 - accuracy: 0.8926 - precision: 0.9496 - recall: 0.8219 - val_loss: 0.4350 - val_accuracy: 0.8800 - val_precision: 0.9366 - val_recall: 0.8136
Epoch 11/100
561/561 [==============================] - 154s 274ms/step - loss: 0.3729 - accuracy: 0.9076 - precision: 0.9545 - recall: 0.8496 - val_loss: 0.3909 - val_accuracy: 0.8932 - val_precision: 0.9407 - val_recall: 0.8386
Epoch 12/100
561/561 [==============================] - 165s 294ms/step - loss: 0.3311 - accuracy: 0.9201 - precision: 0.9583 - recall: 0.8705 - val_loss: 0.3540 - val_accuracy: 0.9043 - val_precision: 0.9437 - val_recall: 0.8562
Epoch 13/100
561/561 [==============================] - 153s 273ms/step - loss: 0.2961 - accuracy: 0.9280 - precision: 0.9614 - recall: 0.8875 - val_loss: 0.3234 - val_accuracy: 0.9126 - val_precision: 0.9476 - val_recall: 0.8751
Epoch 14/100
561/561 [==============================] - 154s 274ms/step - loss: 0.2662 - accuracy: 0.9364 - precision: 0.9649 - recall: 0.9024 - val_loss: 0.2972 - val_accuracy: 0.9226 - val_precision: 0.9512 - val_recall: 0.8865
Epoch 15/100
561/561 [==============================] - 154s 275ms/step - loss: 0.2407 - accuracy: 0.9439 - precision: 0.9673 - recall: 0.9144 - val_loss: 0.2745 - val_accuracy: 0.9280 - val_precision: 0.9549 - val_recall: 0.8972
Epoch 16/100
561/561 [==============================] - 154s 274ms/step - loss: 0.2185 - accuracy: 0.9502 - precision: 0.9706 - recall: 0.9252 - val_loss: 0.2556 - val_accuracy: 0.9322 - val_precision: 0.9564 - val_recall: 0.9041
Epoch 17/100
561/561 [==============================] - 155s 276ms/step - loss: 0.1991 - accuracy: 0.9546 - precision: 0.9726 - recall: 0.9335 - val_loss: 0.2391 - val_accuracy: 0.9365 - val_precision: 0.9592 - val_recall: 0.9122
Epoch 18/100
561/561 [==============================] - 153s 273ms/step - loss: 0.1819 - accuracy: 0.9590 - precision: 0.9752 - recall: 0.9406 - val_loss: 0.2239 - val_accuracy: 0.9398 - val_precision: 0.9624 - val_recall: 0.9188
Epoch 19/100
561/561 [==============================] - 153s 274ms/step - loss: 0.1665 - accuracy: 0.9629 - precision: 0.9773 - recall: 0.9471 - val_loss: 0.2106 - val_accuracy: 0.9431 - val_precision: 0.9646 - val_recall: 0.9226
Epoch 20/100
561/561 [==============================] - 153s 272ms/step - loss: 0.1527 - accuracy: 0.9663 - precision: 0.9789 - recall: 0.9526 - val_loss: 0.1989 - val_accuracy: 0.9480 - val_precision: 0.9664 - val_recall: 0.9287
Epoch 21/100
561/561 [==============================] - 153s 272ms/step - loss: 0.1404 - accuracy: 0.9690 - precision: 0.9803 - recall: 0.9564 - val_loss: 0.1887 - val_accuracy: 0.9496 - val_precision: 0.9677 - val_recall: 0.9342
Epoch 22/100
561/561 [==============================] - 153s 273ms/step - loss: 0.1292 - accuracy: 0.9719 - precision: 0.9826 - recall: 0.9607 - val_loss: 0.1794 - val_accuracy: 0.9534 - val_precision: 0.9689 - val_recall: 0.9378
Epoch 23/100
561/561 [==============================] - 152s 272ms/step - loss: 0.1189 - accuracy: 0.9744 - precision: 0.9840 - recall: 0.9643 - val_loss: 0.1709 - val_accuracy: 0.9547 - val_precision: 0.9697 - val_recall: 0.9407
Epoch 24/100
561/561 [==============================] - 153s 272ms/step - loss: 0.1097 - accuracy: 0.9773 - precision: 0.9859 - recall: 0.9676 - val_loss: 0.1632 - val_accuracy: 0.9576 - val_precision: 0.9693 - val_recall: 0.9427
Epoch 25/100
561/561 [==============================] - 153s 273ms/step - loss: 0.1013 - accuracy: 0.9795 - precision: 0.9869 - recall: 0.9706 - val_loss: 0.1564 - val_accuracy: 0.9588 - val_precision: 0.9692 - val_recall: 0.9465
Epoch 26/100
561/561 [==============================] - 153s 273ms/step - loss: 0.0936 - accuracy: 0.9812 - precision: 0.9885 - recall: 0.9737 - val_loss: 0.1502 - val_accuracy: 0.9601 - val_precision: 0.9710 - val_recall: 0.9489
Epoch 27/100
561/561 [==============================] - 152s 272ms/step - loss: 0.0864 - accuracy: 0.9831 - precision: 0.9895 - recall: 0.9766 - val_loss: 0.1444 - val_accuracy: 0.9623 - val_precision: 0.9709 - val_recall: 0.9523
Epoch 28/100
561/561 [==============================] - 153s 273ms/step - loss: 0.0798 - accuracy: 0.9847 - precision: 0.9906 - recall: 0.9792 - val_loss: 0.1394 - val_accuracy: 0.9639 - val_precision: 0.9716 - val_recall: 0.9545
Epoch 29/100
561/561 [==============================] - 152s 271ms/step - loss: 0.0737 - accuracy: 0.9862 - precision: 0.9916 - recall: 0.9810 - val_loss: 0.1344 - val_accuracy: 0.9641 - val_precision: 0.9719 - val_recall: 0.9561
Epoch 30/100
561/561 [==============================] - 152s 271ms/step - loss: 0.0682 - accuracy: 0.9877 - precision: 0.9926 - recall: 0.9827 - val_loss: 0.1298 - val_accuracy: 0.9652 - val_precision: 0.9728 - val_recall: 0.9581
Epoch 31/100
561/561 [==============================] - 153s 272ms/step - loss: 0.0630 - accuracy: 0.9890 - precision: 0.9934 - recall: 0.9843 - val_loss: 0.1256 - val_accuracy: 0.9663 - val_precision: 0.9733 - val_recall: 0.9592
Epoch 32/100
561/561 [==============================] - 152s 272ms/step - loss: 0.0582 - accuracy: 0.9900 - precision: 0.9944 - recall: 0.9856 - val_loss: 0.1220 - val_accuracy: 0.9672 - val_precision: 0.9738 - val_recall: 0.9599
Epoch 33/100
561/561 [==============================] - 153s 273ms/step - loss: 0.0537 - accuracy: 0.9912 - precision: 0.9952 - recall: 0.9877 - val_loss: 0.1186 - val_accuracy: 0.9674 - val_precision: 0.9736 - val_recall: 0.9610
Epoch 34/100
561/561 [==============================] - 165s 295ms/step - loss: 0.0496 - accuracy: 0.9925 - precision: 0.9958 - recall: 0.9894 - val_loss: 0.1150 - val_accuracy: 0.9677 - val_precision: 0.9741 - val_recall: 0.9625
Epoch 35/100
561/561 [==============================] - 152s 271ms/step - loss: 0.0457 - accuracy: 0.9935 - precision: 0.9965 - recall: 0.9906 - val_loss: 0.1121 - val_accuracy: 0.9674 - val_precision: 0.9734 - val_recall: 0.9637
Epoch 36/100
561/561 [==============================] - 165s 294ms/step - loss: 0.0420 - accuracy: 0.9946 - precision: 0.9971 - recall: 0.9917 - val_loss: 0.1090 - val_accuracy: 0.9679 - val_precision: 0.9743 - val_recall: 0.9645
Epoch 37/100
353/561 [=================>............] - ETA: 45s - loss: 0.0403 - accuracy: 0.9944 - precision: 0.9974 - recall: 0.9923
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[46], line 1
----> 1 model.fit(dataset, epochs=100, validation_data=valdataset, callbacks=[tfbpard, model_check])

File /opt/conda/lib/python3.10/site-packages/keras/utils/traceback_utils.py:65, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     63 filtered_tb = None
     64 try:
---> 65     return fn(*args, **kwargs)
     66 except Exception as e:
     67     filtered_tb = _process_traceback_frames(e.__traceback__)

File /opt/conda/lib/python3.10/site-packages/keras/engine/training.py:1685, in Model.fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1677 with tf.profiler.experimental.Trace(
   1678     "train",
   1679     epoch_num=epoch,
   (...)
   1682     _r=1,
   1683 ):
   1684     callbacks.on_train_batch_begin(step)
-> 1685     tmp_logs = self.train_function(iterator)
   1686     if data_handler.should_sync:
   1687         context.async_wait()

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py:150, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    148 filtered_tb = None
    149 try:
--> 150   return fn(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:894, in Function.__call__(self, *args, **kwds)
    891 compiler = "xla" if self._jit_compile else "nonXla"
    893 with OptionalXlaContext(self._jit_compile):
--> 894   result = self._call(*args, **kwds)
    896 new_tracing_count = self.experimental_get_tracing_count()
    897 without_tracing = (tracing_count == new_tracing_count)

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:926, in Function._call(self, *args, **kwds)
    923   self._lock.release()
    924   # In this case we have created variables on the first call, so we run the
    925   # defunned version which is guaranteed to never create variables.
--> 926   return self._no_variable_creation_fn(*args, **kwds)  # pylint: disable=not-callable
    927 elif self._variable_creation_fn is not None:
    928   # Release the lock early so that multiple threads can perform the call
    929   # in parallel.
    930   self._lock.release()

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py:143, in TracingCompiler.__call__(self, *args, **kwargs)
    140 with self._lock:
    141   (concrete_function,
    142    filtered_flat_args) = self._maybe_define_function(args, kwargs)
--> 143 return concrete_function._call_flat(
    144     filtered_flat_args, captured_inputs=concrete_function.captured_inputs)

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py:1757, in ConcreteFunction._call_flat(self, args, captured_inputs, cancellation_manager)
   1753 possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
   1754 if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
   1755     and executing_eagerly):
   1756   # No tape is watching; skip to running the function.
-> 1757   return self._build_call_outputs(self._inference_function.call(
   1758       ctx, args, cancellation_manager=cancellation_manager))
   1759 forward_backward = self._select_forward_and_backward_functions(
   1760     args,
   1761     possible_gradient_type,
   1762     executing_eagerly)
   1763 forward_function, args_with_tangents = forward_backward.forward()

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py:381, in _EagerDefinedFunction.call(self, ctx, args, cancellation_manager)
    379 with _InterpolateFunctionError(self):
    380   if cancellation_manager is None:
--> 381     outputs = execute.execute(
    382         str(self.signature.name),
    383         num_outputs=self._num_outputs,
    384         inputs=args,
    385         attrs=attrs,
    386         ctx=ctx)
    387   else:
    388     outputs = execute.execute_with_cancellation(
    389         str(self.signature.name),
    390         num_outputs=self._num_outputs,
   (...)
    393         ctx=ctx,
    394         cancellation_manager=cancellation_manager)

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:52, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     50 try:
     51   ctx.ensure_initialized()
---> 52   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     53                                       inputs, attrs, num_outputs)
     54 except core._NotOkStatusException as e:
     55   if name is not None:

KeyboardInterrupt: 

2.7 Save and Plot History of Loss and Metrics¶

In [47]:
hist = pd.DataFrame(model.history.history)
In [48]:
hist
Out[48]:
loss accuracy precision recall val_loss val_accuracy val_precision val_recall
0 2.315288 0.179553 0.189744 0.002063 2.074300 0.270680 0.500000 0.000223
1 1.898178 0.372986 0.850000 0.011372 1.748079 0.444816 0.938931 0.027425
2 1.563593 0.527175 0.917114 0.085735 1.414445 0.578149 0.916121 0.155853
3 1.250982 0.646078 0.922441 0.242656 1.132928 0.677592 0.926224 0.333110
4 0.998160 0.726964 0.924227 0.416801 0.920478 0.746488 0.919617 0.492308
5 0.809847 0.784102 0.929505 0.563019 0.762211 0.796656 0.926550 0.613155
6 0.669706 0.826635 0.936275 0.665868 0.644828 0.826310 0.928055 0.687402
7 0.566022 0.853559 0.940414 0.736384 0.557613 0.845931 0.930447 0.742698
8 0.486354 0.876024 0.945256 0.786387 0.488742 0.865775 0.934435 0.781717
9 0.423297 0.892636 0.949572 0.821896 0.434986 0.880045 0.936602 0.813601
10 0.372941 0.907631 0.954531 0.849601 0.390868 0.893200 0.940720 0.838573
11 0.331124 0.920062 0.958272 0.870506 0.353982 0.904348 0.943721 0.856187
12 0.296109 0.928034 0.961413 0.887508 0.323386 0.912598 0.947610 0.875139
13 0.266226 0.936396 0.964892 0.902391 0.297228 0.922631 0.951196 0.886511
14 0.240728 0.943921 0.967329 0.914376 0.274513 0.927982 0.954912 0.897213
15 0.218527 0.950164 0.970585 0.925191 0.255608 0.932218 0.956368 0.904125
16 0.199123 0.954624 0.972643 0.933497 0.239109 0.936455 0.959203 0.912152
17 0.181903 0.959028 0.975205 0.940576 0.223875 0.939799 0.962401 0.918841
18 0.166533 0.962930 0.977335 0.947098 0.210584 0.943144 0.964569 0.922631
19 0.152725 0.966275 0.978862 0.952561 0.198898 0.948049 0.966357 0.928651
20 0.140375 0.969006 0.980287 0.956352 0.188701 0.949610 0.967667 0.934225
21 0.129151 0.971905 0.982554 0.960700 0.179404 0.953400 0.968901 0.937793
22 0.118920 0.974413 0.984015 0.964268 0.170938 0.954738 0.969662 0.940691
23 0.109728 0.977256 0.985857 0.967557 0.163216 0.957637 0.969280 0.942698
24 0.101337 0.979486 0.986851 0.970623 0.156414 0.958751 0.969178 0.946488
25 0.093560 0.981214 0.988456 0.973744 0.150175 0.960089 0.971024 0.948941
26 0.086403 0.983054 0.989495 0.976643 0.144434 0.962319 0.970903 0.952285
27 0.079827 0.984726 0.990638 0.979152 0.139403 0.963880 0.971630 0.954515
28 0.073720 0.986231 0.991604 0.980991 0.134449 0.964103 0.971895 0.956076
29 0.068197 0.987680 0.992623 0.982663 0.129831 0.965217 0.972832 0.958082
30 0.063044 0.988963 0.993418 0.984336 0.125594 0.966332 0.973303 0.959197
31 0.058202 0.990022 0.994376 0.985562 0.121992 0.967224 0.973762 0.959866
32 0.053703 0.991192 0.995226 0.987736 0.118572 0.967447 0.973571 0.960981
33 0.049615 0.992474 0.995792 0.989409 0.115030 0.967670 0.974052 0.962542
34 0.045664 0.993534 0.996523 0.990635 0.112055 0.967447 0.973423 0.963657
35 0.042015 0.994649 0.997141 0.991694 0.108998 0.967893 0.974324 0.964548
In [49]:
hist.to_csv("hist.csv")
In [50]:
hist[["accuracy", "val_accuracy"]].plot()
Out[50]:
<Axes: >
No description has been provided for this image
In [54]:
hist[['loss', 'val_loss']].plot()
Out[54]:
<Axes: >
No description has been provided for this image
In [55]:
hist[['recall', 'val_recall']].plot()
Out[55]:
<Axes: >
No description has been provided for this image

3. Test On Actual Data¶

3.1 Load the Best Model¶

In [114]:
bestm = tf.keras.models.load_model("model.h5")

3.2 Get Random Test Images¶

In [121]:
test_images = glob.glob("imgs/test/*/*")
In [122]:
def getidx(name):
    """Return the integer index n from a path ending in 'img_<n>.<ext>'.

    Uses os.path instead of splitting on '/', so it also works with
    OS-native path separators (e.g. on Windows).
    """
    stem = os.path.splitext(os.path.basename(name))[0]  # e.g. 'img_9684'
    return int(stem.split("_")[1])
In [123]:
sorted_images = sorted(test_images, key=getidx)

3.3 Preprocess those and Predict¶

In [124]:
n_cases = 16
test_cases = np.random.choice(sorted_images, n_cases)
test_x = np.array([process_one_frameid(x) for x in test_cases])
prediction = bestm.predict(test_x)
y = [map_back[i] for i in prediction.argmax(-1)]
fig,ax = plt.subplots(n_cases, figsize=(5,5*n_cases))
for i in range(n_cases):
    ax[i].imshow(test_x[i])
    ax[i].set_title(y[i])
1/1 [==============================] - 0s 33ms/step
No description has been provided for this image
In [112]:
prediction.argmax(-1)
Out[112]:
array([5, 4, 2, 5, 2, 1, 6, 1, 1, 3, 7, 7, 3, 6, 2, 0])
In [ ]: